import pandas as pd
import torch

from ModularUtils.FunctionsConstant import asKey, getdoKey, get_dataset, save_datasets
from ModularUtils.FunctionsDistribution import get_joint_distributions_from_samples, calculate_TVD, calculate_KL, \
    compare_conditionals_within
from ModularUtils.ControllerConstants import map_dictfill_to_discrete, get_multiple_labels_fill
from ModularUtils.ControllerModel import get_generated_labels, get_fake_distribution
from ModularUtils.DigitImageGeneration.mnist_image_generation import plot_trained_digits
from ModularUtils.FrontBackDoorCalculation import estiamte_ate_backdoor_direct
from ModularUtils.FunctionsTraining import get_training_variables, save_results

def cond_vs_intv(Exp, U0,D,C):
    """Print a conditional estimate next to a backdoor estimate for C given D.

    The three tensors are concatenated column-wise in the order (U0, D, C)
    and moved to a NumPy array. The original author's intent, per the
    in-code note, is to calculate P(C|do(D)) with a backdoor adjustment and
    contrast it with the plain conditional.

    Args:
        Exp: Project experiment object, forwarded untouched to the helpers.
        U0: Tensor placed in column 0 (used as the adjustment set ['U0']).
        D: Tensor placed in column 1 (passed as 'D' to the estimator).
        C: Tensor placed in column 2 (labelled 'genC').

    Returns:
        None. Results are only printed.
    """
    # NOTE(review): the column labels below only line up if each input
    # contributes exactly one column — confirm against the caller.
    stacked = torch.cat((U0, D, C), dim=1).cpu().numpy()

    # ---------Conditional: only the D and genC columns are handed over -----------
    d_and_c = stacked[:, 1:3]
    cond_prob = compare_conditionals_within(Exp, d_and_c, ['genC'], ['D'], ['D', 'genC'])
    print("Conditional:",cond_prob)

    # ---------Backdoor estimate, adjusting for U0 -----------
    column_names = {0: 'U0', 1: 'D', 2: 'genC'}
    frame = pd.DataFrame(stacked).rename(columns=column_names)
    bd_dict = estiamte_ate_backdoor_direct(Exp, frame, 'D', 'genC', ['U0'])

    print("backdoor:")
    for entry in bd_dict:
        print(entry, bd_dict[entry])


def _record_observational_divergence(Exp, label_generators, dataset_dict, key, intv_key,
                                     compare_Var, query_str, tvd_diff, kl_diff):
    """Append TVD/KL between generated and dataset distributions for one query.

    Used only for the observational (empty-intervention) key. The histories
    are updated in place, and only when ``query_str`` is already a key of
    ``tvd_diff``.

    Raises:
        KeyError: if ``dataset_dict`` has no entry for ``intv_key``.
    """
    _, _, _, graph_label_vars = get_training_variables(Exp, Exp.label_names, 0, key)
    obs_indices = [graph_label_vars.index(lb) for lb in compare_Var]

    # Previously a missing key left an empty list here and the later
    # ``.detach()`` failed with an opaque AttributeError; fail explicitly.
    if intv_key not in dataset_dict:
        raise KeyError(f"No dataset stored for intervention key {intv_key!r}")

    current_real_label = (dataset_dict[intv_key]["obs"][:, obs_indices]
                          .type(torch.LongTensor)
                          .view(-1, len(obs_indices))
                          .to(Exp.DEVICE))

    fake_dist_dict = get_fake_distribution(Exp, label_generators, intv_key, compare_Var)
    dataset_dist_dict = get_joint_distributions_from_samples(
        Exp, compare_Var, current_real_label.detach().cpu().numpy().astype(int), "feature")

    obs_tvd = calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False)
    obs_kl = calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False)

    # TODO (carried over from the original): queries that were not
    # pre-registered in tvd_diff are silently skipped, and kl_diff is assumed
    # to hold the same keys as tvd_diff.
    if query_str in tvd_diff:
        tvd_diff[query_str].append(round(obs_tvd, 4))
        kl_diff[query_str].append(round(obs_kl, 4))


def _build_intervention_fills(Exp, dataset_dict, data_input, intervened_Var, real_labels_vars):
    """Return ``{intervened label: tensor}`` to feed the generators.

    Image-valued labels (those in ``Exp.image_labels``) take the stored
    observational images as they are; every other label takes its column of
    ``data_input`` passed through ``get_multiple_labels_fill``. The result is
    an empty dict when ``intervened_Var`` is empty.
    """
    intv_tensor_dict = {}
    for intv_lb in intervened_Var:
        if intv_lb in Exp.image_labels:
            intv_parent_fill = dataset_dict[asKey({})]["img"]
        else:
            ind = real_labels_vars.index(intv_lb)
            parent_intv_label = (data_input[:, ind]
                                 .type(torch.LongTensor)
                                 .view(-1, 1)
                                 .to(Exp.DEVICE))
            intv_parent_fill = get_multiple_labels_fill(
                Exp, parent_intv_label, [Exp.label_dim[intv_lb]], isImage_labels=False)
        intv_tensor_dict[intv_lb] = intv_parent_fill
    return intv_tensor_dict


def _evaluate_generated_mediator(Exp, label_generators, dataset_dict, compare_Var, key):
    """Generate 'genC', print it against the observed column, and save it.

    Steps, all driven by the observational dataset ``dataset_dict[asKey({})]``:
    generate labels with the given generators, discretise the slice treated as
    'genC', print the real and generated distributions, call ``cond_vs_intv``
    and write 'medU0', 'medD' and 'medC' through ``save_datasets``.
    """
    interv_no = 0
    data_input = dataset_dict[asKey({})]["obs"]
    _, compare_Var, intervened_Var, real_labels_vars = get_training_variables(
        Exp, compare_Var, interv_no, key)

    intv_tensor_dict = _build_intervention_fills(
        Exp, dataset_dict, data_input, intervened_Var, real_labels_vars)

    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_tensor_dict,
                                                 real_labels_vars, Exp.Synthetic_Sample_Size)
    y_dims = sum(Exp.label_dim[lb] for lb in real_labels_vars)
    generated_labels_fill = torch.cat(list(generated_labels_dict.values()), 1).view(-1, y_dims)

    # NOTE(review): columns 4:7 are hard-coded; presumably the slot of genC
    # inside the concatenated generator output — confirm against the order
    # and sizes of real_labels_vars.
    generated_c = generated_labels_fill[:, 4:7]
    generated_c = map_dictfill_to_discrete(Exp, {'genC': generated_c}, compare_Var)
    fake_dist_dict = get_joint_distributions_from_samples(Exp, ['genC'], generated_c, "feature")

    # NOTE(review): observational columns are addressed by fixed position
    # (0 -> D, 1 -> U0, 2 -> C), following the original variable names —
    # confirm against the dataset layout.
    real_c = data_input[:, 2].view(-1, 1)
    real_dist_dict = get_joint_distributions_from_samples(Exp, ['genC'], real_c.cpu(), "feature")

    print("xxx", real_dist_dict)
    print("-->", fake_dist_dict)

    real_D = data_input[:, 0].view(-1, 1)
    real_U0 = data_input[:, 1].view(-1, 1)
    cond_vs_intv(Exp, real_U0, real_D, torch.tensor(generated_c).to(Exp.DEVICE))

    # Saving generated C as a dataset for another experiment.
    # NOTE(review): the destination directory is hard-coded.
    label_save_dir = ("/path_to_project/SAVED_EXPERIMENTS/imageMediator/preprocessed_dataset/"
                      + "intv" + str(interv_no))
    save_datasets(True, label_save_dir, "feature", {'medU0': real_U0.cpu().numpy()})
    save_datasets(True, label_save_dir, "feature", {'medD': real_D.cpu().numpy()})
    save_datasets(True, label_save_dir, "feature", {'medC': generated_c})


def trueMediatorEvaluation(Exp, cur_hnodes, label_generators, dataset_dict, tvd_diff, kl_diff):
    """Evaluate the label generators on every query in ``Exp.interv_queries``.

    All generators are switched to eval mode and the work runs under
    ``torch.no_grad()``. For each (query, intervention key) pair:

    * if the key is the empty intervention and the query has observed
      variables, the TVD and KL between the generated and the dataset
      distribution are appended to ``tvd_diff`` / ``kl_diff``;
    * if the key is the empty intervention and the query has no observed
      variables, the pair is skipped entirely;
    * otherwise (and after the first case) the generated mediator is
      produced, printed, compared through ``cond_vs_intv`` and saved.

    Results are then handed to ``save_results``, the generators are put back
    into train mode (also when an exception is raised), and the tail of every
    TVD history is printed.

    Args:
        Exp: Project experiment object (configuration, device, label sizes).
        cur_hnodes: Unused; kept so existing callers keep working.
        label_generators: Mapping of name -> generator module.
        dataset_dict: Mapping of intervention key -> dict holding at least
            an "obs" tensor (and "img" when an image label is intervened on).
        tvd_diff: Mapping of query string -> list of TVD values; mutated.
        kl_diff: Mapping of query string -> list of KL values; mutated.

    Returns:
        The same ``(tvd_diff, kl_diff)`` objects that were passed in.

    Raises:
        KeyError: if the observational dataset is missing from
            ``dataset_dict`` when it is needed.
    """
    for gen in label_generators:
        label_generators[gen].eval()

    try:
        with torch.no_grad():
            # Nothing is collected into these two; they are passed to
            # save_results empty, exactly as before.
            all_generated_labels = {}
            all_real_labels = {}

            for query in Exp.interv_queries:
                for key in query["intervs"]:
                    compare_Var = query["obs"]
                    intv_key = asKey(key)
                    query_str = getdoKey(compare_Var, dict(intv_key))

                    if key == {}:
                        if len(compare_Var) == 0:
                            continue
                        _record_observational_divergence(
                            Exp, label_generators, dataset_dict, key, intv_key,
                            compare_Var, query_str, tvd_diff, kl_diff)

                    _evaluate_generated_mediator(
                        Exp, label_generators, dataset_dict, compare_Var, key)

            save_results(Exp, Exp.SAVED_PATH, all_generated_labels, all_real_labels,
                         tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    finally:
        # Restore train mode even if evaluation failed part-way, so a caller
        # that handles the error does not continue training in eval mode.
        for gen in label_generators:
            label_generators[gen].train()

    # Print the most recent (up to 10) values of each history. Slicing each
    # list on its own avoids the IndexError the old code raised for an empty
    # tvd_diff and the "-0" slice it produced when the first history was empty.
    for dist in tvd_diff:
        print("###", dist, " loss%:", [round(val, 4) for val in tvd_diff[dist][-10:]])
    print(Exp.SAVED_PATH)

    return tvd_diff , kl_diff




# Exp = Experiment("Exp1", set_imageMediator,
#                  new_experiment=False,
#                  features=["feature"],
#                  Data_intervs=[{}])
#
# plot_saved_results(Exp, None, [], epochs=200, delta=5)



